#_*_coding:utf-8_*_
import torch
import torch.nn as nn
import torchvision.datasets as normal_datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from data_process import read_image
# Training hyperparameters.
num_epochs = 5
batch_size = 100
learning_rate = 0.001
 
# TensorBoard event writer (hard-coded experiment directory).
writer = SummaryWriter('/home/newsun/torch_numrec/runs/exp2')
# NOTE(review): this makes every newly created tensor float64, doubling memory
# and slowing training versus the float32 default — confirm this is intended.
torch.set_default_tensor_type(torch.DoubleTensor)
# Move a tensor to the GPU when one is available.
def get_variable(x):
    """Return *x* on the GPU if CUDA is available, otherwise return it unchanged.

    ``torch.autograd.Variable`` has been a no-op alias of ``Tensor`` since
    PyTorch 0.4 (which this file already targets — it calls ``.item()``), so
    the wrapper is dropped; autograd behavior is identical.
    """
    return x.cuda() if torch.cuda.is_available() else x
 
 
# # 从torchvision.datasets中加载一些常用数据集
# train_dataset = normal_datasets.MNIST(
#     root='./data/',  # 数据集保存路径
#     train=True,  # 是否作为训练集
#     transform=transforms.ToTensor(),  # 数据如何处理, 可以自己自定义
#     download=True)  # 路径下没有的话, 可以下载
 
# # 见数据加载器和batch
# test_dataset = normal_datasets.MNIST(root='./data/',
#                                      train=False,
#                                      transform=transforms.ToTensor())
 
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
#                                            batch_size=batch_size,
#                                            shuffle=True)
 
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                           batch_size=batch_size,
#                                           shuffle=False)

 

# Two convolutional stages followed by a 10-way linear classifier.
class CNN(nn.Module):
    """Small CNN: two conv -> batchnorm -> ReLU -> maxpool stages, then a
    fully connected head producing 10 class logits.

    The linear layer assumes the feature map has been pooled down to 4x4
    (i.e. a 1-channel 16x16 input image).
    """

    def __init__(self):
        super(CNN, self).__init__()

        def conv_stage(in_ch, out_ch):
            # 5x5 'same'-padded conv -> batch norm -> ReLU -> 2x2 max pool
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        # Attribute names kept as conv1/conv2/fc so state_dict keys match.
        self.conv1 = conv_stage(1, 16)
        self.conv2 = conv_stage(16, 32)
        self.fc = nn.Linear(4 * 4 * 32, 10)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)  # (N, 4*4*32)
        return self.fc(flat)

if __name__ == "__main__" :
    # Data loaders come from the project-local reader (not torchvision).
    train_loader, test_loader, n_class, train_dataset, test_dataset = read_image('/home/newsun/knntrain')
    cnn = CNN()

    if torch.cuda.is_available():
        cnn = cnn.cuda()

    # Loss function and optimizer.
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        # ---- training pass ----
        cnn.train()
        for i, (images, labels) in enumerate(train_loader):
            images = get_variable(images)
            labels = get_variable(labels)

            outputs = cnn(images)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer.add_scalar('loss', loss.item(),
                              global_step=epoch * (len(train_dataset) // batch_size) + i)
            if (i + 1) % 1 == 0:  # log every batch
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                    % (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, loss.item()))

        # ---- evaluation pass ----
        cnn.eval()
        testing_correct = 0
        with torch.no_grad():  # no gradients needed during evaluation
            for i, (images, labels) in enumerate(test_loader):
                images = get_variable(images)
                labels = get_variable(labels)
                outputs = cnn(images)
                _, pred = torch.max(outputs, 1)
                # .item() converts the 0-dim tensor to a Python int. Without it,
                # testing_correct stays a Long tensor and the division below is
                # integer tensor division on older torch versions, silently
                # truncating the accuracy (and "{:.4f}".format can raise on a
                # tensor operand).
                testing_correct += torch.sum(pred == labels.data).item()
        print("Test accuracy is :{:.4f}".format(100 * testing_correct / len(test_dataset)))

    # Flush pending TensorBoard events before exiting.
    writer.close()

    # Save model weights only (not the full module object).
    torch.save(cnn.state_dict(), "cnn16.pt")